import copy
import logging

import gym
import numpy as np

from metaworld_complex.envs.sawyer_complex_v2 import SawyerComplexEnvV2
from metaworld_complex.policies import CustomSawyerReturnV2Policy
from diffgro.environments.metaworld_complex.policies import (
    make_metaworld_complex_policy,
)

##################################################

# Maps each short task key accepted in a ``task_list`` to the descriptive
# sub-task name that is joined into the environment's ``env_name``.
task_dict = dict(
    puck="push_puck",
    drawer="close_drawer",
    button="press_button",
    # door="open_door",
    stick="peg_insert",
)


class MetaWorldComplexEnv(gym.Env):
    """Gym wrapper around ``SawyerComplexEnvV2`` for a list of sub-tasks.

    The wrapper counts sub-goal successes, ends the episode when every task
    in ``task_list`` has succeeded or when ``max_steps`` is reached, and
    appends the xyz part of the previously applied action to every
    observation.
    """

    def __init__(self, task_list, max_steps: int = 2000, seed: int = 777):
        """Create the wrapped environment and the observation/action spaces.

        Args:
            task_list: Sequence of task keys; each key must exist in
                ``task_dict``.
            max_steps: Step count at which the episode is flagged as done.
            seed: Seed value stored on the instance (it is only forwarded to
                the wrapped env when ``reset`` is called with a seed).

        Raises:
            KeyError: If a task key is not present in ``task_dict``.
        """
        self.domain_name = "metaworld_complex"
        self.task_list = task_list
        self.task_num = len(task_list)
        self.max_steps = max_steps
        self._seed = seed

        # initialize with task list
        self.initialize(task_list)

        self.full_task_list = [task_dict[task] for task in task_list]
        self.env_name = "_and_".join(self.full_task_list)

        # Observation spec:
        #   hand position   : [0:3]
        #   hand gripper    : [3]
        #   sub_task goals  : 17 * num_tasks
        #   previous action : last 3 entries (xyz), appended by this wrapper
        # e.g. with four tasks: 4 + 68 + 3 = 75
        self.observation_space = gym.spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=(4 + 17 * len(task_list) + 3,),
            dtype=np.float32,
        )
        self.action_space = gym.spaces.Box(
            self.env.action_space.low, self.env.action_space.high, dtype=np.float32
        )

    def _append_prev_action(self, obs):
        """Return ``obs`` with the xyz part of ``self.prev_action`` appended."""
        # np.concatenate always builds a new array, so the result never
        # aliases ``self.prev_action``.
        return np.concatenate((obs, self.prev_action[:3]), axis=-1)

    def reset(self, seed=None, warmup=False):
        """Reset the episode counters and the wrapped environment.

        Args:
            seed: When not ``None``, stored on the instance and forwarded to
                ``self.env.seed``.
            warmup: Accepted for signature compatibility; not used here.

        Returns:
            The wrapped env's reset observation with three zeros (the
            initial "previous action") appended.
        """
        if seed is not None:
            self._seed = seed
            self.env.seed(seed)

        self.timesteps = 0
        self.success_count = 0
        self.prev_action = np.zeros(self.action_space.shape)
        obs = self.env.reset()
        return self._append_prev_action(obs)

    def step(self, action):
        """Apply one action and update the success bookkeeping.

        Args:
            action: Either an action array or an ``(action, skill)`` tuple;
                the skill element is ignored by this class.

        Returns:
            ``(obs, rew, done, info)`` where ``rew`` is 1 on a step whose
            wrapped-env ``info["success"]`` is truthy and 0 otherwise.
            ``info`` gains ``speed``, ``force``, ``force_axis``, ``energy``
            and ``sub_goal_success``; ``info["success"]`` is replaced by the
            fraction of completed tasks when the episode ends.
        """
        if isinstance(action, tuple):
            action, _ = action  # the skill label is not used here

        action = np.clip(action, -1, 1)
        obs, rew, done, info = self.env.step(action)
        self.timesteps += 1

        # Action statistics over the xyz components.
        xyz = action[:3]
        info["speed"] = np.linalg.norm(xyz - self.prev_action[:3])
        info["force"] = np.linalg.norm(xyz)
        info["force_axis"] = np.abs(xyz)
        info["energy"] = np.sum(np.abs(xyz))

        # A truthy success from the wrapped env counts as one finished
        # sub-goal; it is reported separately so that "success" can describe
        # the whole episode.
        sub_goal_success = bool(info["success"])
        rew = 1 if sub_goal_success else 0
        info["sub_goal_success"] = sub_goal_success
        if sub_goal_success:
            self.success_count += 1
            info["success"] = False

        num_tasks = len(self.task_list)
        if self.success_count == num_tasks:
            self.env.mode -= 1  # To prevent index error on last _get_obs() call
            info["success"] = self.success_count / num_tasks
            done = True

        if self.timesteps == self.max_steps:
            info["success"] = self.success_count / num_tasks
            done = True

        self.prev_action = action
        return self._append_prev_action(obs), rew, done, info

    def render(self, offscreen=True, camera_name="corner3", resolution=(640, 480)):
        """Return whatever the wrapped env's ``render`` produces."""
        return self.env.render(offscreen, camera_name, resolution)

    def initialize(self, task_list):
        """Build the wrapped ``SawyerComplexEnvV2`` for ``task_list``."""
        self.env = SawyerComplexEnvV2(task_list)
        self.env._partially_observable = False

    def get_exp(self):
        """Return the policy built by ``make_metaworld_complex_policy`` with no variant."""
        return make_metaworld_complex_policy(self.env, variant=None)


class MetaWorldComplexVariantEnv(MetaWorldComplexEnv):
    """``MetaWorldComplexEnv`` whose dynamics depend on a sampled variant.

    A variant is a mapping with the keys ``arm_speed``, ``goal_resistance``,
    ``wind_xspeed`` and ``wind_yspeed``. In this class only
    ``goal_resistance`` influences ``step``; the other values are stored on
    the instance.
    """

    # Force thresholds, indexed by a task's resistance level. After
    # ``_set_goal_resistance`` runs, the instance attribute of the same name
    # is a dict mapping task key -> threshold that shadows this list.
    goal_threshold = [0.2, 0.5, 0.8]
    # MuJoCo joint whose damping is adjusted for push/pull skills, per task.
    joint_name = {
        "handle": "handleslide",
        "button": "btnbox_joint",
        "drawer": "drawerslide",
        "lever": "LeverAxis",
        "door": "doorjoint",
    }

    # Number of return-policy steps executed by ``_warmup``.
    _WARMUP_STEPS = 50
    # Observation entry compared over time to detect drawer movement (the
    # original code labels it "drawer-close").
    # NOTE(review): hard-coded index; presumably depends on where the drawer
    # task sits in ``task_list`` -- confirm.
    _DRAWER_OBS_INDEX = 22
    _logger = logging.getLogger(__name__)

    def __init__(self, variant_space, *args, **kwargs):
        """Create the env; remaining arguments go to ``MetaWorldComplexEnv``.

        Args:
            variant_space: Object providing ``sample()`` and a
                ``variant_config`` mapping.
        """
        super().__init__(*args, **kwargs)
        self.variant_space = variant_space
        self.variant = None

        # Not referenced anywhere else in this class.
        self.threshold = 1e9
        self.joint_id = None

    def _warmup(self):
        """Run the return policy for a fixed number of steps.

        Returns:
            The last observation, with its trailing previous-action entries
            replaced by zeros.
        """
        return_policy = CustomSawyerReturnV2Policy()
        for _ in range(self._WARMUP_STEPS):
            act = return_policy.get_action(self.env._last_stable_obs)
            obs = self.step(act)[0]

        self.timesteps = 0
        # Keep the stored previous action consistent with the zeroed entries
        # reported in the returned observation; otherwise the first real
        # step's info["speed"] would be measured against the last warm-up
        # action.
        self.prev_action = np.zeros(self.action_space.shape)
        return np.concatenate((obs[:-3], self.prev_action[:3]), axis=-1)

    def reset(self, *args, variant=None, warmup=False, **kwargs):
        """Reset the env, select a variant and run the warm-up.

        Args:
            variant: Variant to use; when ``None`` one is drawn from
                ``self.variant_space.sample()``.
            warmup: Currently ignored -- the warm-up always runs.
            *args, **kwargs: Forwarded to ``MetaWorldComplexEnv.reset``.

        Returns:
            The observation produced by ``_warmup``.
        """
        super().reset(*args, **kwargs)

        self.variant = variant if variant is not None else self.variant_space.sample()
        self.reset_model()
        self.damping = False
        return self._warmup()

    def step(self, action):
        """Apply one action under the current variant's resistance rules.

        Args:
            action: Either an action array or an ``(action, skill)`` tuple.
                The caller's array is never modified.

        Returns:
            ``(obs, rew, done, info)`` from ``MetaWorldComplexEnv.step``;
            ``info`` additionally carries ``damping`` and, while damping is
            active, ``damping_rew``.

        Raises:
            ValueError: If damping is active but the drawer resistance level
                is neither 1 nor 2.
        """
        skill = None
        if isinstance(action, tuple):
            action, skill = action

        # Work on a private copy: the resistance branch below rescales
        # action[:3] in place, which must not leak into the caller's array.
        action = np.array(action)

        # [arm_speed] : no change
        # action[:3] = action[:3]
        # [wind_xspeed] : x axis change
        # action[0] = action[0] / self.wind_xspeed
        # [wind_yspeed] : y axis change
        # action[1] = action[1] / self.wind_yspeed

        # [resistance]
        if skill in ("push", "pull"):
            current_task = self.task_list[self.env.mode]

            # NOTE(review): raises KeyError for tasks without an entry in
            # ``joint_name`` (e.g. "puck", "stick") -- confirm intended.
            joint_id = self.env.sim.model.joint_name2id(self.joint_name[current_task])
            goal_threshold = self.goal_threshold[current_task]

            force = np.linalg.norm(action[:3])
            if force < goal_threshold:
                # Too weak: make the joint effectively immovable.
                self.env.sim.model.dof_damping[joint_id] = 100000
            else:
                # Strong enough: rescale to a fixed magnitude and free the joint.
                action[:3] = action[:3] / force * 0.3
                self.env.sim.model.dof_damping[joint_id] = 1

        damping_rew = None
        if self.damping:
            force = np.linalg.norm(action[:3])
            level = self.variant["goal_resistance"]["drawer"]
            if level == 1:
                thresh = 0.65
            elif level == 2:
                thresh = 0.6
            else:
                raise ValueError(
                    f"no damping force threshold for drawer resistance level {level!r}"
                )
            if force < thresh:
                action = action * 0.0  # no action applied
                damping_rew = 0.0
            else:
                damping_rew = 1.0
            self._logger.debug(
                "step=%s thresh=%s force=%s damping_rew=%s",
                self.timesteps,
                thresh,
                force,
                damping_rew,
            )

        # environment step
        obs, rew, done, info = super().step(action)

        # Remember the reference drawer reading on the first step.
        if self.timesteps < 2:
            self.obj_pos = obs[self._DRAWER_OBS_INDEX]  # drawer-close

        drawer_level = self.variant["goal_resistance"]["drawer"]
        # The short-circuit skips the task lookup when resistance is
        # disabled, exactly as the separate checks did.
        drawer_active = (
            drawer_level != 0 and self.task_list[self.env.mode] == "drawer"
        )

        # [check when to resist]
        if (
            skill is None
            and not self.damping
            and drawer_level >= 1
            and drawer_active
        ):
            if obs[self._DRAWER_OBS_INDEX] > self.obj_pos:
                self.damping = True
                self._logger.debug(
                    "checking damping at %s to %s", self.timesteps, self.damping
                )
        # [reset damping]
        if not drawer_active:
            self.damping = False

        info["damping"] = self.damping
        if damping_rew is not None:
            info["damping_rew"] = damping_rew
        return obs, rew, done, info

    def get_exp(self):
        """Return the policy built by ``make_metaworld_complex_policy`` for the current variant."""
        return make_metaworld_complex_policy(self.env, variant=self.variant)

    def reset_model(self):
        """Copy the current variant's settings onto the instance.

        Raises:
            ValueError: If no variant has been selected yet.
        """
        if self.variant is None:
            raise ValueError("variant must be set before reset_model() is called")

        self._set_arm_speed(self.variant["arm_speed"])
        self._set_goal_resistance(self.variant["goal_resistance"])
        self._set_wind_xspeed(self.variant["wind_xspeed"])
        self._set_wind_yspeed(self.variant["wind_yspeed"])

    def update_variant_space(self, variant_space):
        """Overwrite entries of this env's ``variant_config`` with those of ``variant_space``."""
        for key, value in variant_space.variant_config.items():
            self.variant_space.variant_config[key] = value

    def _set_arm_speed(self, arm_speed):
        """Store ``arm_speed`` on the instance."""
        self.arm_speed = arm_speed

    def _set_wind_xspeed(self, wind_xspeed):
        """Store ``wind_xspeed`` on the instance."""
        self.wind_xspeed = wind_xspeed

    def _set_wind_yspeed(self, wind_yspeed):
        """Store ``wind_yspeed`` on the instance."""
        self.wind_yspeed = wind_yspeed

    def _set_goal_resistance(self, goal_resistance):
        """Translate per-task resistance levels into force thresholds.

        Each level is used as an index into the class-level
        ``goal_threshold`` list; the resulting task -> threshold dict is
        stored as the instance attribute ``goal_threshold``.
        """
        self.goal_threshold = {
            k: self.__class__.goal_threshold[v] for k, v in goal_resistance.items()
        }
